import os.path  # 用于管理路径
import sys  # 用于在argvTo[0]中找到脚本名称
import pandas as pd
import time
import pandas_ta as ta

# Project root: the parent of the directory holding the launched script. Kept as a
# plain string ending in '/../' so sub-paths can be appended with '+'.
proj_path = '{}/../'.format(os.path.dirname(os.path.abspath(sys.argv[0])))

# Moving-average window lengths (in rows) for the close price and for volume.
g_ma_list = [5, 10, 20, 30, 60, 120, 250]
g_vol_ma_list = [5, 10, 135]
# Number of previous rows copied into each row as lagged columns.
g_shift_n = 5
# A stock is used for ML only if its daily CSV has more than this many rows.
g_ml_min_period = 1500


# MACD indicator: DIF line, DEA signal line and histogram
def MACD(df, n_fast, n_slow, ksgn='close'):
    """Return a copy of *df* with three MACD columns joined on.

    Columns added, in this order:
      * ``macd``  - histogram, ``2 * (mdiff - mdea)``
      * ``mdea``  - EMA of ``mdiff`` with span ``max(int((n_fast + n_slow) / 4), 2)``
      * ``mdiff`` - EMA(df[ksgn], n_fast) - EMA(df[ksgn], n_slow)

    Every EMA uses ``min_periods = span - 1``. *df* itself is not modified.
    Joining fails if *df* already has any of these column names.
    """
    price = df[ksgn]
    fast_ema = price.ewm(span=n_fast, min_periods=n_fast - 1).mean()
    slow_ema = price.ewm(span=n_slow, min_periods=n_slow - 1).mean()
    dif = (fast_ema - slow_ema).rename('mdiff')

    signal_span = max(int((n_fast + n_slow) / 4), 2)
    dea = dif.ewm(span=signal_span, min_periods=signal_span - 1).mean().rename('mdea')

    histogram = ((dif - dea) * 2).rename('macd')
    return df.join(histogram).join(dea).join(dif)


# Simple moving average
def MA_n(df, n, ksgn='close'):
    """Return a copy of *df* with an *n*-row simple moving average of ``df[ksgn]``.

    The new column is ``ma_<n>`` when *ksgn* is ``'close'`` and
    ``<ksgn>_ma_<n>`` otherwise (e.g. ``volume_ma_5``). Rows without a full
    window of *n* non-NaN values are NaN. *df* itself is not modified.
    """
    prefix = '' if ksgn == 'close' else ksgn + '_'
    column_name = '{}ma_{}'.format(prefix, n)
    average = df[ksgn].rolling(window=n).mean().rename(column_name)
    return df.join(average)


# MACD histogram: areas of the most recent n red / green segments
def macd_ext(df, n):
    """Add MACD crossing flags and the areas of the last *n* completed red/green segments.

    Requires a ``macd`` column (e.g. produced by ``MACD``). Adds to *df* in place:
      * ``macd_1a``     - ``macd`` shifted down by one row
      * ``macd_switch`` - 1 where ``macd`` goes from < 0 to > 0 (green -> red),
                          -1 where it goes from > 0 to < 0 (red -> green), else 0

    Returns a new frame with ``red0, green0, red1, green1, ...`` joined on.
    ``redj`` / ``greenj`` hold the summed ``macd`` of the j-th most recently
    *completed* red / green segment as of that row (0.0 until that many
    segments have completed). The segment still in progress is not included.
    Rows where ``macd`` is NaN get 0.0 in every area column.

    NOTE(review): a crossing is detected only between two consecutive rows with
    strictly opposite signs. If ``macd`` is exactly 0 on the way across, no
    switch is recorded and the two segments are summed together.
    """
    macd = df['macd']
    df['macd_1a'] = macd.shift(1)
    # NaN comparisons are False, so rows next to NaN never count as a switch.
    turned_red = (macd > 0) & (df['macd_1a'] < 0)
    turned_green = (macd < 0) & (df['macd_1a'] > 0)
    df['macd_switch'] = turned_red.astype('int64') - turned_green.astype('int64')

    n_rows = df.shape[0]
    # One independent list per lag (``[[0.0] * n_rows] * n`` would alias them).
    red = [[0.0] * n_rows for _ in range(n)]
    green = [[0.0] * n_rows for _ in range(n)]
    curr_red = [0.0] * n    # curr_red[0] = most recently completed red segment
    curr_green = [0.0] * n
    accu_value = 0          # running sum of the segment in progress

    # Plain arrays instead of per-row .iloc lookups.
    values = macd.to_numpy()
    switches = df['macd_switch'].to_numpy()
    for i, (value, switch) in enumerate(zip(values, switches)):
        if pd.isna(value):
            continue
        if switch == 1:
            # A green segment just ended: push its area, start a red one.
            curr_green = ([accu_value] + curr_green)[:n]
            accu_value = value
        elif switch == -1:
            # A red segment just ended: push its area, start a green one.
            curr_red = ([accu_value] + curr_red)[:n]
            accu_value = value
        else:
            accu_value += value
        for j in range(n):
            red[j][i] = curr_red[j]
            green[j][i] = curr_green[j]

    areas = {}
    for j in range(n):
        areas['red{}'.format(j)] = red[j]
        areas['green{}'.format(j)] = green[j]
    return df.join(pd.DataFrame(areas, index=df.index))


# Low-volume bearish bar right after a surge on the previous day
def shrink_negative_line(df):
    """Flag a shrinking-volume bearish bar that follows a >9% gain on the previous row.

    Sets ``df['shrink_negative_line']`` to 1 where all of these hold, else 0:
      * previous row gained more than 9%: ``(close_1a - close_2a) / close_2a > 0.09``
      * volume shrank: ``volume < volume_1a``
      * bearish bar: ``close < open``
      * low above the previous low: ``low > low_1a``
      * close below the previous close: ``close < close_1a``

    Needs the lagged columns produced by ``shift_till_n`` (names end in ``_<i>a``).
    Rows with NaN in any input get 0. Modifies *df* in place and returns it.
    """
    prev_gain = (df['close_1a'] - df['close_2a']) / df['close_2a']
    flag = (
        (prev_gain > 0.09)
        & (df['volume'] < df['volume_1a'])
        & (df['close'] < df['open'])
        & (df['low'] > df['low_1a'])
        & (df['close'] < df['close_1a'])
    )
    df['shrink_negative_line'] = flag.astype('int64')
    return df


# Shrinking volume
def shrink_volume(df):
    """Set ``df['shrink_volume']`` to 1 where ``volume < volume_1a``, else 0.

    ``volume_1a`` is the previous row's volume (from ``shift_till_n``); NaN
    comparisons give 0. Vectorised instead of a per-row ``apply``. Modifies
    *df* in place and returns it.
    """
    df['shrink_volume'] = (df['volume'] < df['volume_1a']).astype('int64')
    return df


# Volume surge: volume above its 135-row moving average
def volume_boom(df):
    """Set ``df['volume_boom']`` to 1 where ``volume > volume_ma_135``, else 0.

    ``volume_ma_135`` is the column ``MA_n(df, 135, 'volume')`` creates; NaN
    comparisons give 0. Modifies *df* in place and returns it.
    """
    df['volume_boom'] = (df['volume'] > df['volume_ma_135']).astype('int64')
    return df


# Price surge: gain of more than 9% over the previous close
def value_boom(df):
    """Set ``df['value_boom']`` to 1 where ``(close - close_1a) / close_1a > 0.09``, else 0.

    ``close_1a`` is the previous row's close (from ``shift_till_n``). NaN
    results give 0. A zero previous close yields inf/NaN rather than raising.
    Modifies *df* in place and returns it.
    """
    gain = (df['close'] - df['close_1a']) / df['close_1a']
    df['value_boom'] = (gain > 0.09).astype('int64')
    return df


# Bottom fractal
def bottom_shape(df):
    """Flag rows where the previous bar is a bottom fractal.

    Sets ``df['bottom_shape']`` to 1 when the bar one row back has both a
    lower low and a lower high than the bar two rows back and the current bar:
    ``low_1a < low_2a``, ``low_1a < low``, ``high_1a < high_2a``, ``high_1a < high``.
    Otherwise 0 (including NaN inputs). Modifies *df* in place and returns it.
    """
    middle_low = df['low_1a']
    middle_high = df['high_1a']
    flag = (
        (middle_low < df['low_2a'])
        & (middle_low < df['low'])
        & (middle_high < df['high_2a'])
        & (middle_high < df['high'])
    )
    df['bottom_shape'] = flag.astype('int64')
    return df


# Recovery after an abnormal-volume bar
def retrieve_special_volume(df):
    """Flag a bullish bar that recovers above the high of an abnormal-volume bar.

    Sets ``df['retrieve_special_volume']`` to 1 where ``special_volume_1a == 1``
    (the previous row was flagged by ``special_volume``), ``close > high_1a``
    and ``close > open``; otherwise 0, including NaN inputs.
    Modifies *df* in place and returns it.
    """
    flag = (
        (df['special_volume_1a'] == 1)
        & (df['close'] > df['high_1a'])
        & (df['close'] > df['open'])
    )
    df['retrieve_special_volume'] = flag.astype('int64')
    return df


# Bullish (up) bar
def positive(df):
    """Set ``df['positive']`` to 1 where ``close > open``, else 0.

    Vectorised comparison; NaN inputs give 0. Modifies *df* in place and returns it.
    """
    df['positive'] = (df['close'] > df['open']).astype('int64')
    return df


# Bearish (down) bar
def negative(df):
    """Set ``df['negative']`` to 1 where ``close < open``, else 0.

    Vectorised comparison; NaN inputs give 0. Modifies *df* in place and returns it.
    """
    df['negative'] = (df['close'] < df['open']).astype('int64')
    return df


# Abnormal volume: bearish bar, lower close, higher volume
def special_volume(df):
    """Flag a bearish bar that closes lower on higher volume.

    Sets ``df['special_volume']`` to 1 where ``open > close``,
    ``close < close_1a`` and ``volume > volume_1a``; otherwise 0, including
    NaN inputs. Modifies *df* in place and returns it.
    """
    flag = (
        (df['open'] > df['close'])
        & (df['close'] < df['close_1a'])
        & (df['volume'] > df['volume_1a'])
    )
    df['special_volume'] = flag.astype('int64')
    return df


# Copy each of the previous n rows' indicator values into the current row
def shift_till_n(df, indicator_list, n):
    """Add lagged columns ``<ind>_1a`` ... ``<ind>_<n>a`` for every name in *indicator_list*.

    Calls ``shift_i`` once per lag 1..n, so *df* is modified in place; the same
    object is returned for chaining.
    """
    for lag in range(1, n + 1):
        shift_i(df, indicator_list, lag)
    return df


# Copy the indicator values from exactly i rows earlier into the current row
def shift_i(df, indicator_list, i):
    """Add ``<ind>_<i>a`` = ``df[ind].shift(i)`` for every name in *indicator_list*.

    Columns are assigned one at a time in list order, modifying *df* in place;
    the same object is returned for chaining.
    """
    for column in indicator_list:
        lagged = df[column].shift(i)
        df['{}_{}a'.format(column, i)] = lagged
    return df


# Did the window's maximum exceed its last value by percent_change?
def max_profit(x, percent_change=0.1):
    """Return 1 if ``(max(x) - x.iloc[-1]) / x.iloc[-1] >= percent_change``, else 0.

    *x* is a pandas Series window; its last element is the reference value and
    the builtin ``max`` runs over all elements, including that reference.
    A NaN ratio compares False and yields 0.
    """
    base = x.iloc[-1]
    rise = (max(x) - base) / base
    return 1 if rise >= percent_change else 0


# Classification label: does the close rise by percent_change within `days` rows?
def class_label(df, days, percent_change):
    """Add a 0/1 label telling whether the close rises by *percent_change* within *days* rows.

    For row ``i`` the label is 1.0 when
    ``(max(close[i], ..., close[i + days]) - close[i]) / close[i] >= percent_change``
    and 0.0 otherwise. It is NaN for the last *days* rows (incomplete
    look-ahead window) and for windows containing a NaN close.

    The column is named ``label_<days>_<percent_change * 100>%`` (for example
    ``label_5_9.5%``). Modifies *df* in place and returns it.

    Equivalent to a rolling ``max_profit`` over the reversed close series, but
    uses a rolling ``max`` instead of calling a Python function per window.
    """
    close = df['close']
    # Reverse, take a trailing max over days + 1 rows, reverse back:
    # future_max[i] == max(close[i], ..., close[i + days]).
    future_max = close.iloc[::-1].rolling(days + 1).max().iloc[::-1]
    gain = (future_max - close) / close
    # Keep NaN where the look-ahead window is incomplete or contains NaN.
    label = (gain >= percent_change).astype('float64').where(future_max.notna())
    df['label_{}_{}%'.format(days, percent_change * 100)] = label
    return df


def _ensure_ml_stock_code_file(stock_code_file, day_dir):
    """Create *stock_code_file* (a single ``code`` column) if it does not exist yet.

    Candidate codes come from ``data/tdx/all_stock_codes.csv``. A code is kept
    when ``<day_dir><code>.csv`` exists and has more than ``g_ml_min_period``
    rows, so there is enough history to train on. An existing file is left as is.
    """
    if os.path.exists(stock_code_file):
        return
    all_stock_code_file = proj_path + 'data/tdx/all_stock_codes.csv'
    # Read codes as text: pandas would otherwise parse purely numeric codes as
    # integers (dropping leading zeros) and break the path concatenation below.
    stock_codes = pd.read_csv(all_stock_code_file, encoding='unicode_escape', dtype={'code': str})
    ml_stock_list = []
    for code in stock_codes['code']:
        input_file = day_dir + code + '.csv'
        if not os.path.exists(input_file):
            continue
        if pd.read_csv(input_file).shape[0] > g_ml_min_period:
            ml_stock_list.append(code)
    pd.DataFrame(ml_stock_list, columns=['code']).to_csv(stock_code_file, index=False)


def _build_features(df):
    """Return *df* (one stock's daily rows) with every feature and label column added.

    Steps, in order: pandas_ta ``strategy`` indicators (dpo, psar, supertrend,
    ichimoku and hilo excluded), volume and price moving averages, lagged
    copies, the custom pattern flags, MACD with red/green segment areas, and
    the classification labels. Several helpers modify *df* in place.
    """
    # NOTE(review): rows are processed in file order; shifts, moving averages and
    # look-ahead labels assume the CSV is already chronological — TODO confirm.
    df.ta.strategy(exclude=['dpo', 'psar', 'supertrend', 'ichimoku', 'hilo'], verbose=True, timed=True)

    # Volume and price moving averages (volume_ma_<n>, ma_<n>).
    for window in g_vol_ma_list:
        df = MA_n(df, window, 'volume')
    for window in g_ma_list:
        df = MA_n(df, window)

    # Lagged copies (<name>_1a .. <name>_<g_shift_n>a) of raw prices and averages.
    indicator_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
    indicator_list += ['ma_{}'.format(window) for window in g_ma_list]
    indicator_list += ['volume_ma_{}'.format(window) for window in g_vol_ma_list]
    df = shift_till_n(df, indicator_list, g_shift_n)

    # Abnormal volume, its recovery, and the bottom fractal.
    df = special_volume(df)
    df = shift_till_n(df, ['special_volume'], g_shift_n)
    df = retrieve_special_volume(df)
    df = bottom_shape(df)

    # MACD plus areas of the last 3 red/green histogram segments.
    df = MACD(df, 12, 26)
    df = macd_ext(df, 3)

    # Price surge, volume surge, shrinking volume (each with lagged copies).
    df = value_boom(df)
    df = shift_till_n(df, ['value_boom'], g_shift_n)
    df = volume_boom(df)
    df = shift_till_n(df, ['volume_boom'], g_shift_n)
    df = shrink_volume(df)
    df = shift_till_n(df, ['shrink_volume'], g_shift_n)
    # shrink_negative_line is not applied here.

    # Bullish / bearish bars.
    df = positive(df)
    df = negative(df)
    df = shift_till_n(df, ['positive', 'negative'], g_shift_n)

    # Classification targets: (look-ahead rows, required rise).
    label_specs = ((1, 0.095), (2, 0.095), (5, 0.095), (10, 0.095),
                   (2, 0.195), (5, 0.195), (10, 0.195))
    for days, percent_change in label_specs:
        df = class_label(df, days, percent_change)
    return df


def main():
    """Write an ML feature CSV per selected stock into data/extension/d/ml/."""
    time_start = time.time()

    day_dir = proj_path + 'data/tdx/day/'
    stock_code_file = proj_path + 'data/tdx/ml_stock_code.csv'
    _ensure_ml_stock_code_file(stock_code_file, day_dir)
    stock_codes = pd.read_csv(stock_code_file, encoding='unicode_escape', dtype={'code': str})

    out_dir = proj_path + 'data/extension/d/ml/'
    os.makedirs(out_dir, exist_ok=True)

    for code in stock_codes['code']:
        print('processing {}...'.format(code))
        input_file = day_dir + code + '.csv'
        if not os.path.exists(input_file):
            continue
        df = _build_features(pd.read_csv(input_file))
        df.to_csv(out_dir + code + '.csv', index=False)
        print(code + ' done!')

    time_end = time.time()
    print('程序所耗时间：', time_end - time_start)


if __name__ == '__main__':
    main()
